{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Lecture 5: The *k*-Nearest Neighbor Method\n",
    "===\n",
    "Instructor: Gao Peng\n",
    "---\n",
    "Office: Room 407, School of Cyberspace Security\n",
    "---\n",
    "Email: pgao@qfnu.edu.cn\n",
    "---\n",
    "Program: Software Engineering (Intelligent Data)\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compatibility with both Python 2 and Python 3\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import pandas as pd\n",
    "import sklearn\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "The $k$-nearest neighbor method ($k$-nearest neighbor, $k$-NN) is a basic classification and regression method. We discuss only the $k$-NN method for classification. The input of $k$-NN is the feature vector of an instance, corresponding to a point in feature space; the output is the class of the instance, which may be any of several classes. $k$-NN assumes a training data set is given in which the class of every instance is fixed. To classify a new instance, its class is predicted by majority voting (or another rule) over the classes of its $k$ nearest training instances. Thus $k$-NN has no explicit learning process; it in effect uses the training data set to partition the feature space, and this partition serves as its classification \"model\". <font color=#ff0000>The choice of $k$</font>, <font color=#ff0000>the distance metric</font>, and <font color=#ff0000>the classification decision rule</font> are the three basic elements of $k$-NN. The method was proposed by Cover and Hart in 1968.\n",
    "\n",
    "This lecture first presents the $k$-NN algorithm, then discusses the $k$-NN model and its three basic elements, and finally describes one way to implement $k$-NN, the kd tree, with algorithms for constructing and searching kd trees."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The k-Nearest Neighbor Algorithm\n",
    "\n",
    "The $k$-NN algorithm is simple and intuitive: given a training data set and a new input instance, find the $k$ instances in the training set nearest to it; whichever class holds the majority among these $k$ instances is assigned to the input instance. We first state the algorithm and then discuss its details.\n",
    "\n",
    "**Algorithm 1 ($k$-nearest neighbor method)**\n",
    "\n",
    "Input: training data set\n",
    "\n",
    "$$\n",
    "T=\\{(x_1,y_1),(x_2,y_2),\\ldots,(x_N,y_N)\\}\n",
    "$$\n",
    "\n",
    "where $x_i\\in\\mathcal{X}=\\mathbb{R}^n$ is the feature vector of an instance and $y_i\\in\\mathcal{Y}=\\{c_1,c_2,\\ldots,c_K\\}$ is its class, $i=1,2,\\ldots,N$; and the feature vector $x$ of the instance to classify;\n",
    "\n",
    "Output: the class $y$ of instance $x$.\n",
    "\n",
    "(1) Using the given distance metric, find the $k$ points in the training set $T$ nearest to $x$; denote the neighborhood of $x$ covering these $k$ points by $N_k(x)$;\n",
    "\n",
    "(2) In $N_k(x)$, decide the class $y$ of $x$ by a classification decision rule (e.g. majority voting):\n",
    "\n",
    "$$\n",
    "y=\\arg\\max_{c_j}\\sum_{x_i\\in N_k(x)}I(y_i=c_j),\\quad i=1,2,\\ldots,N,\\quad j=1,2,\\ldots,K\n",
    "$$\n",
    "\n",
    "where $I$ is the indicator function: $I=1$ when $y_i=c_j$, and $I=0$ otherwise.\n",
    "\n",
    "The special case $k=1$ is called the nearest-neighbor algorithm. For an input instance point (feature vector) $x$, the nearest-neighbor method assigns to $x$ the class of the training point nearest to $x$.\n",
    "\n",
    "The $k$-NN method has no explicit learning process."
   ]
  },
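  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps of Algorithm 1 can be sketched directly as a linear scan plus majority voting. This is a minimal illustration; the name `knn_predict` is a hypothetical choice for this sketch, not a library API.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch of Algorithm 1: linear scan + majority vote\n",
    "from collections import Counter\n",
    "\n",
    "def knn_predict(X_train, y_train, x, k=3, p=2):\n",
    "    # L_p distance from x to every training point (linear scan)\n",
    "    dists = [sum(abs(a - b) ** p for a, b in zip(xi, x)) ** (1.0 / p)\n",
    "             for xi in X_train]\n",
    "    # indices of the k smallest distances, i.e. N_k(x)\n",
    "    nearest = sorted(range(len(X_train)), key=lambda i: dists[i])[:k]\n",
    "    # majority vote over the labels in N_k(x)\n",
    "    return Counter(y_train[i] for i in nearest).most_common(1)[0][0]\n",
    "\n",
    "# toy usage: two classes in the plane\n",
    "X = [(1, 1), (2, 1), (1, 2), (6, 5), (7, 6), (6, 6)]\n",
    "y = [0, 0, 0, 1, 1, 1]\n",
    "print(knn_predict(X, y, (2, 2), k=3))    # -> 0\n",
    "print(knn_predict(X, y, (6, 5.5), k=3))  # -> 1"
   ]
  },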
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The k-Nearest Neighbor Model\n",
    "\n",
    "\n",
    "The model used by $k$-NN corresponds to a partition of the feature space. The model is determined by its three basic elements: the distance metric, the choice of $k$, and the classification decision rule.\n",
    "\n",
    "## Model\n",
    "\n",
    "In $k$-NN, once the training set, the distance metric (e.g. Euclidean distance), the value of $k$, and the classification decision rule (e.g. majority voting) are fixed, the class of any new input instance is uniquely determined. This amounts to partitioning the feature space into subregions according to these elements and fixing the class of every point in each subregion. This fact is seen most clearly in the nearest-neighbor algorithm.\n",
    "\n",
    "In feature space, for each training instance point $x_i$, the set of all points closer to $x_i$ than to any other training point forms a region called a cell. Each training instance point has one cell, and the cells of all training points form a partition of the feature space. The nearest-neighbor method takes the class $y_i$ of instance $x_i$ as the class label of every point in its cell, so the class of the points in each cell is determined. The figure below shows an example partition of a two-dimensional feature space.\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"400\" src=\"Lesson5-1.jpg\">\n",
    "</p>\n",
    "\n",
    "## Distance Metrics\n",
    "\n",
    "The distance between two instance points in feature space reflects how similar they are. The feature space of the $k$-NN model is usually the $n$-dimensional real vector space $\\mathbb{R}^n$. The distance used is the Euclidean distance, but other distances can be used, such as the more general $L_p$ distance (Minkowski distance).\n",
    "\n",
    "Let the feature space $\\mathcal{X}$ be the $n$-dimensional real vector space $\\mathbb{R}^n$, with $x_i,x_j\\in\\mathcal{X}$, $x_i=(x_i^{(1)},x_i^{(2)},\\ldots,x_i^{(n)})^T$, $x_j=(x_j^{(1)},x_j^{(2)},\\ldots,x_j^{(n)})^T$. The $L_p$ distance between $x_i$ and $x_j$ is defined as\n",
    "\n",
    "$$\n",
    "L_p(x_i,x_j)=\\left(\\sum^n_{l=1}|x_i^{(l)}-x_j^{(l)}|^p\\right)^{\\frac{1}{p}}\n",
    "$$\n",
    "\n",
    "where $p\\geq 1$.\n",
    "\n",
    "When $p=2$ it is the Euclidean distance:\n",
    "\n",
    "$$\n",
    "L_2(x_i,x_j)=\\left(\\sum^n_{l=1}|x_i^{(l)}-x_j^{(l)}|^2\\right)^{\\frac{1}{2}}\n",
    "$$\n",
    "\n",
    "When $p=1$ it is the Manhattan distance:\n",
    "\n",
    "$$\n",
    "L_1(x_i,x_j)=\\sum^n_{l=1}|x_i^{(l)}-x_j^{(l)}|\n",
    "$$\n",
    "\n",
    "When $p=\\infty$ it is the maximum of the coordinate-wise distances:\n",
    "\n",
    "$$\n",
    "L_\\infty(x_i,x_j)=\\max_l|x_i^{(l)}-x_j^{(l)}|\n",
    "$$\n",
    "\n",
    "The figure below shows, in two dimensions, the points whose $L_p$ distance from the origin equals 1 ($L_p=1$) for several values of $p$.\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"400\" src=\"Lesson5-2.jpg\">\n",
    "</p>\n",
    "\n",
    "## Example\n",
    "\n",
    "Given three points in two-dimensional space, $x_1=(1,1)^T$, $x_2=(5,1)^T$, $x_3=(4,4)^T$, find the nearest neighbor of $x_1$ under the $L_p$ distance for different values of $p$.\n",
    "\n",
    "**Solution**  Since $x_1$ and $x_2$ differ only in the first coordinate, $L_p(x_1,x_2)=4$ for every $p$. Meanwhile,\n",
    "\n",
    "$$\n",
    "L_1(x_1,x_3)=6,\\quad L_2(x_1,x_3)=4.24,\\quad L_3(x_1,x_3)=3.78,\\quad L_4(x_1,x_3)=3.57\n",
    "$$\n",
    "\n",
    "Hence for $p$ equal to 1 or 2, $x_2$ is the nearest neighbor of $x_1$; for $p\\geq 3$, $x_3$ is the nearest neighbor of $x_1$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from itertools import combinations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1 = [1, 1]\n",
    "x2 = [5, 1]\n",
    "x3 = [4, 4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def L(x, y, p=2):\n",
    "    # L_p distance between vectors x and y (p >= 1)\n",
    "    if len(x) == len(y) and len(x) > 1:\n",
    "        total = 0  # avoid shadowing the built-in sum\n",
    "        for i in range(len(x)):\n",
    "            total += math.pow(abs(x[i] - y[i]), p)\n",
    "        return math.pow(total, 1.0 / p)  # 1.0/p keeps the exponent fractional under Python 2 as well\n",
    "    else:\n",
    "        return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for c in [x2, x3]:\n",
    "    for i in range(1, 5):\n",
    "        r = L(x1, c, p=i)\n",
    "        print('[1, 1] -',c,':',r)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choice of the Value of k\n",
    "\n",
    "The choice of $k$ has a major effect on the results of the $k$-NN method.\n",
    "\n",
    "Choosing a smaller $k$ amounts to predicting with training instances in a smaller neighborhood. The approximation error of \"learning\" decreases, since only training instances close to (similar to) the input instance affect the prediction. The drawback is that the estimation error of \"learning\" increases: the prediction becomes very sensitive to the nearby instance points. If a nearby instance point happens to be noise, the prediction will be wrong. In other words, decreasing $k$ makes the overall model more complex and prone to overfitting.\n",
    "\n",
    "Choosing a larger $k$ amounts to predicting with training instances in a larger neighborhood. The advantage is that the estimation error of learning decreases; the drawback is that the approximation error increases, because training instances far from (dissimilar to) the input instance then also affect the prediction and can make it wrong. Increasing $k$ makes the overall model simpler.\n",
    "\n",
    "If $k=N$, then whatever the input instance is, it is simply predicted to belong to the most frequent class among the training instances. Such a model is too simple and completely ignores a large amount of useful information in the training instances; it is not advisable.\n",
    "\n",
    "In applications, $k$ is usually taken to be a fairly small number, and cross-validation is typically used to select the optimal $k$.\n",
    "\n",
    "## Classification Decision Rule\n",
    "\n",
    "The classification decision rule in $k$-NN is usually majority voting: the majority class among the $k$ nearest training instances of the input instance determines its class.\n",
    "\n",
    "The majority voting rule has the following interpretation. If the loss function for classification is the 0-1 loss and the classification function is\n",
    "\n",
    "$$\n",
    "f:\\mathbb{R}^n\\rightarrow\\{c_1,c_2,\\ldots,c_K\\}\n",
    "$$\n",
    "\n",
    "then the probability of misclassification is\n",
    "\n",
    "$$\n",
    "P(Y \\neq f(X))=1-P(Y=f(X))\n",
    "$$\n",
    "\n",
    "For a given instance $x\\in\\mathcal{X}$, its $k$ nearest training instance points form the set $N_k(x)$. If the class of the region covering $N_k(x)$ is $c_j$, the misclassification rate is\n",
    "\n",
    "$$\n",
    "\\frac{1}{k}\\sum_{x_i\\in N_k(x)}I(y_i\\neq c_j)=1-\\frac{1}{k}\\sum_{x_i\\in N_k(x)}I(y_i= c_j)\n",
    "$$\n",
    "\n",
    "To minimize the misclassification rate, i.e. the empirical risk, one must maximize $\\sum_{x_i\\in N_k(x)}I(y_i=c_j)$, so the majority voting rule is equivalent to empirical risk minimization."
   ]
  },
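  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of selecting $k$ by cross-validation, using scikit-learn on a synthetic two-class data set invented purely for this illustration (the variable names `X_cv`, `y_cv`, `best_k` are assumptions of the sketch):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# synthetic two-class data, for illustration only\n",
    "rng = np.random.RandomState(0)\n",
    "X_cv = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(3, 1, (50, 2))])\n",
    "y_cv = np.array([0] * 50 + [1] * 50)\n",
    "\n",
    "# 5-fold cross-validated accuracy for each candidate k\n",
    "scores = {}\n",
    "for k in range(1, 16, 2):  # odd k avoids ties in binary majority voting\n",
    "    clf = KNeighborsClassifier(n_neighbors=k)\n",
    "    scores[k] = cross_val_score(clf, X_cv, y_cv, cv=5).mean()\n",
    "\n",
    "best_k = max(scores, key=scores.get)\n",
    "print(best_k, round(scores[best_k], 3))"
   ]
  },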
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing k-NN: kd Trees\n",
    "\n",
    "When implementing the $k$-NN method, the main concern is how to search the training data for the $k$ nearest neighbors quickly. This is especially important when the dimension of the feature space is large and the training set is big.\n",
    "\n",
    "The simplest implementation of $k$-NN is a linear scan: compute the distance between the input instance and every training instance. When the training set is large, this is very time-consuming and impractical.\n",
    "\n",
    "To make $k$-nearest-neighbor search more efficient, special data structures can be used to store the training data and reduce the number of distance computations. Many such methods exist; below we introduce one of them, the kd tree.\n",
    "\n",
    "## Constructing a kd Tree\n",
    "\n",
    "A kd tree is a tree data structure that stores instance points in $k$-dimensional space for fast retrieval. It is a binary tree representing a partition of $k$-dimensional space. Constructing a kd tree amounts to repeatedly splitting $k$-dimensional space with hyperplanes perpendicular to the coordinate axes, forming a series of $k$-dimensional hyperrectangular regions; each node of the kd tree corresponds to one such hyperrectangle.\n",
    "\n",
    "The construction proceeds as follows: build a root node corresponding to the hyperrectangle containing all instance points; then recursively split the space to generate child nodes. At each node, choose a coordinate axis and a split point on that axis, which determine a hyperplane through the split point and perpendicular to the chosen axis; this hyperplane divides the current hyperrectangle into left and right subregions (child nodes), and the instances fall into the two subregions. The process stops when a subregion contains no instances (the nodes at termination are the leaves). Along the way, instances are stored at the corresponding nodes.\n",
    "\n",
    "Usually the coordinate axes are chosen in turn and the split point is the median of the training instances' coordinates on the chosen axis; the resulting kd tree is balanced. Note that a balanced kd tree is not necessarily the most efficient for search.\n",
    "\n",
    "The algorithm for constructing a kd tree is given below.\n",
    "\n",
    "**Algorithm 2 (constructing a balanced kd tree)**\n",
    "\n",
    "Input: a data set in $k$-dimensional space, $T=\\{x_1,x_2,\\ldots,x_N\\}$, where $x_i=(x_i^{(1)},x_i^{(2)},\\ldots,x_i^{(k)})^T$, $i=1,2,\\ldots,N$\n",
    "\n",
    "Output: a kd tree\n",
    "\n",
    "(1) Start: construct the root node, corresponding to the $k$-dimensional hyperrectangle containing $T$.\n",
    "\n",
    "Choose $x^{(1)}$ as the axis and the median of the $x^{(1)}$ coordinates of all instances in $T$ as the split point, dividing the root's hyperrectangle into two subregions. The split is realized by the hyperplane through the split point and perpendicular to the $x^{(1)}$ axis.\n",
    "\n",
    "The root generates left and right children of depth 1: the left child corresponds to the subregion whose $x^{(1)}$ coordinates are smaller than the split point, the right child to the subregion whose $x^{(1)}$ coordinates are larger.\n",
    "\n",
    "Instance points lying on the splitting hyperplane are stored at the root.\n",
    "\n",
    "(2) Repeat: for a node of depth $j$, choose $x^{(l)}$ as the splitting axis, with $l=(j \\bmod k)+1$, and take the median of the $x^{(l)}$ coordinates of the instances in the node's region as the split point, dividing the node's hyperrectangle into two subregions. The split is realized by the hyperplane through the split point and perpendicular to the $x^{(l)}$ axis.\n",
    "\n",
    "The node generates left and right children of depth $j+1$: the left child corresponds to the subregion whose $x^{(l)}$ coordinates are smaller than the split point, the right child to the subregion whose $x^{(l)}$ coordinates are larger.\n",
    "\n",
    "Instance points lying on the splitting hyperplane are stored at that node.\n",
    "\n",
    "(3) Stop when the two subregions contain no instances. This yields the region partition of the kd tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The main fields of each kd-tree node\n",
    "# node definition\n",
    "class KdNode(object):\n",
    "    def __init__(self, dom_elt, split, left, right):\n",
    "        self.dom_elt = dom_elt  # a k-dimensional vector (a sample point in k-dimensional space)\n",
    "        self.split = split  # integer: index of the splitting dimension\n",
    "        self.left = left  # kd-tree of the left subspace of this node's splitting hyperplane\n",
    "        self.right = right  # kd-tree of the right subspace of this node's splitting hyperplane\n",
    "\n",
    "# kd tree\n",
    "class KdTree(object):\n",
    "    def __init__(self, data):\n",
    "        k = len(data[0])  # data dimensionality\n",
    "        \n",
    "        # build the kd tree\n",
    "        def CreateNode(split, data_set):  # create a KdNode by splitting data_set on dimension split\n",
    "            if not data_set:  # empty data set\n",
    "                return None\n",
    "            # sort by the dimension to be split on; key takes a one-argument\n",
    "            # function whose return value is used for comparison\n",
    "            data_set.sort(key=lambda x: x[split])\n",
    "            split_pos = len(data_set) // 2  # // is integer division in Python\n",
    "            median = data_set[split_pos]  # the median is the split point\n",
    "            split_next = (split + 1) % k  # cycle through the coordinates\n",
    "\n",
    "            # recursively build the kd tree\n",
    "            return KdNode(\n",
    "                median,\n",
    "                split,\n",
    "                CreateNode(split_next, data_set[:split_pos]),  # build the left subtree\n",
    "                CreateNode(split_next, data_set[split_pos + 1:]))  # build the right subtree\n",
    "\n",
    "        self.root = CreateNode(0, data)  # start splitting on dimension 0; returns the root node\n",
    "\n",
    "\n",
    "# preorder traversal of the kd tree\n",
    "def preorder(root):\n",
    "    print(root.dom_elt)\n",
    "    if root.left:  # node is not empty\n",
    "        preorder(root.left)\n",
    "    if root.right:\n",
    "        preorder(root.right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search the constructed kd tree for the sample point nearest to a target point\n",
    "from math import sqrt\n",
    "from collections import namedtuple\n",
    "\n",
    "# a namedtuple holding the nearest point, the nearest distance, and the number of nodes visited\n",
    "result = namedtuple(\"Result_tuple\",\n",
    "                    \"nearest_point  nearest_dist  nodes_visited\")\n",
    "\n",
    "\n",
    "def find_nearest(tree, point):\n",
    "    k = len(point)  # data dimensionality\n",
    "\n",
    "    def travel(kd_node, target, max_dist):\n",
    "        if kd_node is None:\n",
    "            return result([0] * k, float(\"inf\"), 0)  # float(\"inf\") / float(\"-inf\") denote +/- infinity\n",
    "\n",
    "        nodes_visited = 1\n",
    "\n",
    "        s = kd_node.split  # the splitting dimension\n",
    "        pivot = kd_node.dom_elt  # the splitting point (the \"axis\")\n",
    "\n",
    "        if target[s] <= pivot[s]:  # target's s-th coordinate is below the split (target is closer to the left subtree)\n",
    "            nearer_node = kd_node.left  # visit the left subtree next\n",
    "            further_node = kd_node.right  # remember the right subtree\n",
    "        else:  # target is closer to the right subtree\n",
    "            nearer_node = kd_node.right  # visit the right subtree next\n",
    "            further_node = kd_node.left\n",
    "\n",
    "        temp1 = travel(nearer_node, target, max_dist)  # descend to the region containing the target\n",
    "\n",
    "        nearest = temp1.nearest_point  # take this leaf as the \"current nearest point\"\n",
    "        dist = temp1.nearest_dist  # update the nearest distance\n",
    "\n",
    "        nodes_visited += temp1.nodes_visited\n",
    "\n",
    "        if dist < max_dist:\n",
    "            max_dist = dist  # the nearest point lies inside the hypersphere centered at the target with radius max_dist\n",
    "\n",
    "        temp_dist = abs(pivot[s] - target[s])  # distance from the target to the splitting hyperplane along dimension s\n",
    "        if max_dist < temp_dist:  # does the hypersphere intersect the hyperplane?\n",
    "            return result(nearest, dist, nodes_visited)  # no intersection: return directly, nothing more to check\n",
    "\n",
    "        #----------------------------------------------------------------------\n",
    "        # Euclidean distance between the target and the split point\n",
    "        temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))\n",
    "\n",
    "        if temp_dist < dist:  # if closer\n",
    "            nearest = pivot  # update the nearest point\n",
    "            dist = temp_dist  # update the nearest distance\n",
    "            max_dist = dist  # update the hypersphere radius\n",
    "\n",
    "        # check whether the region of the other child contains a closer point\n",
    "        temp2 = travel(further_node, target, max_dist)\n",
    "\n",
    "        nodes_visited += temp2.nodes_visited\n",
    "        if temp2.nearest_dist < dist:  # a closer point exists in the other child's region\n",
    "            nearest = temp2.nearest_point  # update the nearest point\n",
    "            dist = temp2.nearest_dist  # update the nearest distance\n",
    "\n",
    "        return result(nearest, dist, nodes_visited)\n",
    "\n",
    "    return travel(tree.root, point, float(\"inf\"))  # recurse from the root"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example\n",
    "Given a data set in two-dimensional space:\n",
    "\n",
    "$$\n",
    "T=\\{(2,3)^T,(5,4)^T,(9,6)^T,(4,7)^T,(8,1)^T,(7,2)^T\\}\n",
    "$$\n",
    "\n",
    "construct a balanced kd tree.\n",
    "\n",
    "**Solution**  The root node corresponds to the rectangle containing the data set $T$. Choose the $x^{(1)}$ axis; the median of the six points' $x^{(1)}$ coordinates is 7, so the plane $x^{(1)}=7$ divides the space into left and right subrectangles (child nodes). The left rectangle is then split by $x^{(2)}=4$ and the right rectangle by $x^{(2)}=6$; continuing recursively yields the feature-space partition and kd tree shown in the figure.\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"300\" src=\"Lesson5-3.jpg\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]\n",
    "kd = KdTree(data)\n",
    "preorder(kd.root)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "  <img width=\"400\" src=\"Lesson5-4.jpg\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Searching a kd Tree\n",
    "\n",
    "We now describe how to perform $k$-nearest-neighbor search with a kd tree. As we will see, a kd tree allows most data points to be skipped during search, greatly reducing the amount of computation. We describe the nearest-neighbor case; the same method extends to $k$ nearest neighbors.\n",
    "\n",
    "Given a target point, search for its nearest neighbor as follows. First find the leaf containing the target point; then, starting from that leaf, back up through parent nodes, repeatedly looking for the node nearest the target, and stop when no closer node can possibly exist. The search is thus confined to a local part of the space, and efficiency improves greatly.\n",
    "\n",
    "The leaf containing the target point corresponds to the smallest hyperrectangle containing it. Take the instance point at this leaf as the current nearest point. The target's true nearest neighbor must lie inside the hypersphere centered at the target and passing through the current nearest point. Then return to the node's parent: if the hyperrectangle of the parent's other child intersects this hypersphere, look inside the intersecting region for an instance point closer to the target. If such a point exists, it becomes the new current nearest point. The algorithm then moves up to the next parent and repeats. If the other child's hyperrectangle does not intersect the hypersphere, or no closer point exists, the search stops.\n",
    "\n",
    "The nearest-neighbor search algorithm with a kd tree is stated below.\n",
    "\n",
    "**Algorithm 3 (nearest-neighbor search with a kd tree)**\n",
    "\n",
    "Input: a constructed kd tree; a target point $x$\n",
    "\n",
    "Output: the nearest neighbor of $x$\n",
    "\n",
    "(1) Find the leaf containing the target point $x$: starting from the root, descend the kd tree recursively. If the coordinate of $x$ in the current splitting dimension is smaller than that of the split point, move to the left child; otherwise move to the right child. Continue until a leaf is reached.\n",
    "\n",
    "(2) Take this leaf as the \"current nearest point\".\n",
    "\n",
    "(3) Back up recursively, performing the following at each node:\n",
    "\n",
    "(a) If the instance point stored at the node is closer to the target than the current nearest point, take that instance point as the \"current nearest point\".\n",
    "\n",
    "(b) The current nearest point lies in the region of one of the node's children. Check whether the region of the other child contains a closer point; concretely, check whether that region intersects the hypersphere centered at the target with radius equal to the distance from the target to the \"current nearest point\".\n",
    "\n",
    "If they intersect, a closer point may exist in the other child's region; move to that child and continue the nearest-neighbor search recursively. If they do not intersect, back up.\n",
    "\n",
    "(4) The search ends when it backs up to the root. The final \"current nearest point\" is the nearest neighbor of $x$.\n",
    "\n",
    "If the instance points are randomly distributed, the average complexity of kd-tree search is $\\mathcal{O}(\\log N)$, where $N$ is the number of training instances. kd trees are best suited to $k$-nearest-neighbor search when the number of training instances is much larger than the dimension of the space; as the dimension approaches the number of instances, efficiency drops rapidly toward that of a linear scan.\n",
    "\n",
    "The following example illustrates the search method.\n",
    "\n",
    "## Example\n",
    "\n",
    "Given the kd tree shown in the figure, with root A, children B, C, etc., storing 7 instance points in total, and an input target point S, find the nearest neighbor of S.\n",
    "\n",
    "**Solution**  First find the leaf D containing point S (the lower-right region in the figure) and take D as the approximate nearest neighbor. The true nearest neighbor must lie inside the circle centered at S and passing through D. Then return to D's parent B and search for the nearest neighbor in the region of B's other child F. The region of node F does not intersect the circle, so it cannot contain the nearest neighbor; continue up to the parent A and search the region of A's other child C. C's region intersects the circle; among the instance points of that region inside the circle is E, which is closer than D and becomes the new approximate nearest neighbor. Finally, E is the nearest neighbor of S.\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"400\" src=\"Lesson5-5.jpg\">\n",
    "</p>"
   ]
  },
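  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small sanity check of the search procedure (a sketch using scikit-learn's `KDTree` rather than the implementation above): the nearest neighbor returned by kd-tree search should agree with a brute-force linear scan. The random data here is made up purely for this check.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.neighbors import KDTree\n",
    "\n",
    "rng = np.random.RandomState(1)\n",
    "pts = rng.rand(200, 3)   # 200 random points in 3-D\n",
    "target = rng.rand(1, 3)\n",
    "\n",
    "tree = KDTree(pts, leaf_size=10)\n",
    "dist, ind = tree.query(target, k=1)  # kd-tree nearest-neighbor search\n",
    "\n",
    "# brute-force linear scan for comparison\n",
    "brute = int(np.argmin(np.linalg.norm(pts - target, axis=1)))\n",
    "\n",
    "print(int(ind[0][0]) == brute)  # -> True\n"
   ]
  },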
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example\n",
    "\n",
    "Given instance points in two-dimensional space, draw the space partitions produced by the $k$-NN method for $k=1$ and $k=2$, compare them, and observe the relationship between the choice of $k$, model complexity, and prediction accuracy.\n",
    "\n",
    "**Solution**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from matplotlib.colors import ListedColormap\n",
    "\n",
    "data = np.array([[5, 12, 1], [6, 21, 0], [14, 5, 0], [16, 10, 0], [13, 19, 0], [13, 32, 1],\n",
    "                 [17, 27, 1], [18, 24, 1], [20, 20, 0], [23, 14, 1],[23, 25, 1], [23, 31, 1],\n",
    "                 [26, 8, 0], [30, 17, 1], [30, 26, 1], [34, 8, 0], [34, 19, 1], [37, 28, 1]])\n",
    "X_train = data[:, 0:2]\n",
    "y_train = data[:, 2]\n",
    "\n",
    "models = (KNeighborsClassifier(n_neighbors=1, n_jobs=-1), KNeighborsClassifier(n_neighbors=2, n_jobs=-1))\n",
    "models = (clf.fit(X_train, y_train) for clf in models)\n",
    "\n",
    "titles = ('K Neighbors with k=1', 'K Neighbors with k=2')\n",
    "\n",
    "fig = plt.figure(figsize=(15, 5))\n",
    "plt.subplots_adjust(wspace=0.4, hspace=0.4)\n",
    "\n",
    "X0, X1 = X_train[:, 0], X_train[:, 1]\n",
    "\n",
    "x_min, x_max = X0.min() - 1, X0.max() + 1\n",
    "y_min, y_max = X1.min() - 1, X1.max() + 1\n",
    "xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.2),\n",
    "                     np.arange(y_min, y_max, 0.2))\n",
    "\n",
    "for clf, title, ax in zip(models, titles, fig.subplots(1, 2).flatten()):\n",
    "    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])\n",
    "    Z = Z.reshape(xx.shape)\n",
    "    colors = ('red', 'green', 'lightgreen', 'gray', 'cyan')\n",
    "    cmap = ListedColormap(colors[:len(np.unique(Z))])\n",
    "    ax.contourf(xx, yy, Z, cmap=cmap, alpha=0.5)\n",
    "    ax.scatter(X0, X1, c=y_train, s=50, edgecolors='k', cmap=cmap, alpha=0.5)\n",
    "    ax.set_title(title)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example\n",
    "\n",
    "Use a kd tree to find the nearest neighbor of the point $x=(3,4.5)^T$.\n",
    "\n",
    "**Solution**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.neighbors import KDTree\n",
    "\n",
    "train_data = np.array([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])\n",
    "tree = KDTree(train_data, leaf_size=2)\n",
    "dist, ind = tree.query(np.array([(3, 4.5)]), k=1)\n",
    "x1 = train_data[ind[0]][0][0]\n",
    "x2 = train_data[ind[0]][0][1]\n",
    "\n",
    "print(\"The nearest neighbor of x is ({0}, {1})\".format(x1, x2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Summary\n",
    "\n",
    "1. The $k$-nearest neighbor method is a basic and simple classification and regression method. Its basic procedure: given training instance points and an input instance point, first determine the $k$ training instance points nearest to the input point, then predict the input point's class by the majority class among these $k$ training points.\n",
    "\n",
    "2. The $k$-NN model corresponds to a partition of the feature space based on the training data set. In $k$-NN, once the training set, the distance metric, the value of $k$, and the classification decision rule are fixed, the result is uniquely determined.\n",
    "\n",
    "3. The three elements of $k$-NN: the distance metric, the choice of $k$, and the classification decision rule. Common distance metrics are the Euclidean distance and the more general $L_p$ distance. A small $k$ makes the $k$-NN model more complex; a large $k$ makes it simpler. The choice of $k$ reflects a trade-off between approximation error and estimation error; the optimal $k$ is usually selected by cross-validation.\n",
    "\n",
    "The common classification decision rule is majority voting, which corresponds to empirical risk minimization.\n",
    "\n",
    "4. Implementing $k$-NN requires fast search for the $k$ nearest neighbors. The kd tree is a data structure for fast retrieval of data in $k$-dimensional space. It is a binary tree representing a partition of $k$-dimensional space, with each node corresponding to a hyperrectangular region of the partition. kd trees allow most data points to be skipped during search, reducing the amount of computation."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
